#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from weakref import ref as weak_ref_to
from typing import List, Union, Any
from typing_extensions import Self

import numpy as np

import vedo.file_io
import vedo.vtkclasses as vtki  # a wrapper for lazy imports

import vedo
from vedo.transformations import LinearTransform
from vedo.visual import CommonVisual, Actor3DHelper

# Docstring convention declared for documentation tools.
__docformat__ = "google"

__doc__ = """
Submodule for managing groups of vedo objects

![](https://vedo.embl.es/images/basic/align4.png)
"""

# Public names exported by this submodule.
__all__ = [
    "Group",
    "Assembly",
    "procrustes_alignment",
]


#################################################
def procrustes_alignment(sources: List["vedo.Mesh"], rigid=False) -> "Assembly":
    """
    Return an `Assembly` of aligned source meshes with the `Procrustes` algorithm.
    The output `Assembly` is normalized in size.

    The `Procrustes` algorithm takes N set of points and aligns them in a least-squares sense
    to their mutual mean. The algorithm is iterated until convergence,
    as the mean must be recomputed after each alignment.

    The set of average points generated by the algorithm can be accessed with
    `algoutput.info['mean']` as a numpy array.

    Arguments:
        sources : (list of Mesh)
            meshes to align (any iterable is accepted).
            They must all have the same number of points.
        rigid : bool
            if `True` scaling is disabled.

    Returns:
        an `Assembly` holding one aligned `Mesh` per source, in input order.
        Its `.transform` is the landmark transform of the filter and
        `.info["mean"]` holds the mean points.

    Raises:
        RuntimeError: if no source is given, or if the sources do not all
            have the same number of points.

    Examples:
        - [align4.py](https://github.com/marcomusy/vedo/tree/master/examples/basic/align4.py)

        ![](https://vedo.embl.es/images/basic/align4.png)
    """
    # materialize so that generators work and items can be indexed
    sources = list(sources)
    if not sources:
        raise RuntimeError("procrustes_alignment: at least one source mesh is required")

    npts = sources[0].npoints
    if any(source.npoints != npts for source in sources):
        vedo.logger.error("sources have different nr of points")
        raise RuntimeError()

    group = vtki.new("MultiBlockDataGroupFilter")
    for source in sources:
        group.AddInputData(source.dataset)

    procrustes = vtki.new("ProcrustesAlignmentFilter")
    procrustes.StartFromCentroidOn()
    procrustes.SetInputConnection(group.GetOutputPort())
    if rigid:
        procrustes.GetLandmarkTransform().SetModeToRigidBody()
    procrustes.Update()

    # block i of the output corresponds to the i-th input source
    aligned = procrustes.GetOutput()
    acts = []
    for i, s in enumerate(sources):
        mesh = vedo.mesh.Mesh(aligned.GetBlock(i))
        # the aligned copy shares the visual property object of its source
        mesh.actor.SetProperty(s.actor.GetProperty())
        mesh.properties = s.actor.GetProperty()
        if hasattr(s, "name"):
            mesh.name = s.name
        acts.append(mesh)

    assem = Assembly(acts)
    assem.transform = procrustes.GetLandmarkTransform()
    assem.info["mean"] = vedo.utils.vtk2numpy(procrustes.GetMeanPoints().GetData())
    return assem


#################################################
class Group:
    """Form groups of generic objects (not necessarily meshes)."""

    def __init__(self, objects=()):
        """
        Form groups of generic objects (not necessarily meshes).

        Arguments:
            objects : (list, dict, object)
                objects to group; nested lists/tuples are flattened.
                If a dictionary is given, each object is renamed after its key.
                Items can be vedo objects (exposing an `.actor`) or raw VTK props.
        """
        if isinstance(objects, dict):
            for name in objects:
                objects[name].name = name
            objects = list(objects.values())
        elif not vedo.utils.is_sequence(objects):
            # a single object: wrap it so it can be flattened/iterated below
            objects = [objects]

        self.actor = vtki.vtkPropAssembly()
        self.actor.retrieve_object = weak_ref_to(self)

        self.name = "Group"
        self.filename = ""
        self.trail = None
        self.trail_points = []
        self.trail_segment_size = 0
        self.trail_offset = None
        self.shadows = []
        self.info = {}
        self.rendered_at = set()
        self.scalarbar = None

        # every item added to the group, as it was passed in
        self.objects = []
        for a in vedo.utils.flatten(objects):
            if a:
                self.actor.AddPart(self._prop_of(a))
                self.objects.append(a)

        self.actor.PickableOff()

    @staticmethod
    def _prop_of(obj):
        """Return the VTK prop for `obj`: its `.actor` if it has one, else `obj` itself."""
        return getattr(obj, "actor", obj)

    def __str__(self):
        """Print info about Group object."""
        module = self.__class__.__module__
        name = self.__class__.__name__
        out = vedo.printc(
            f"{module}.{name} at ({hex(id(self))})".ljust(75),
            bold=True, invert=True, return_string=True,
        )
        out += "\x1b[0m"
        if self.name:
            out += "name".ljust(14) + ": " + self.name
            if "legend" in self.info.keys() and self.info["legend"]:
                out+= f", legend='{self.info['legend']}'"
            out += "\n"

        n = len(self.unpack())
        out += "n. of objects".ljust(14) + ": " + str(n) + " "
        # raw VTK props carry no `name`, so read it defensively
        names = [getattr(a, "name", "") for a in self.objects]
        names = [nm for nm in names if nm]
        if names:
            out += str(names).replace("'","")[:56]
        return out.rstrip() + "\x1b[0m"

    def __iadd__(self, obj):
        """Add an object (or a list of objects) to the group."""
        items = obj if vedo.utils.is_sequence(obj) else [obj]
        for a in items:
            if a:
                self.actor.AddPart(self._prop_of(a))
                self.objects.append(a)
        return self

    def __isub__(self, obj):
        """Remove an object (or a list of objects) from the group."""
        items = obj if vedo.utils.is_sequence(obj) else [obj]
        for a in items:
            if a:
                self.actor.RemovePart(self._prop_of(a))
                # identity comparison: VTK/vedo objects need not define __eq__
                self.objects = [o for o in self.objects if o is not a]
        return self

    def rename(self, name: str) -> "Group":
        """Set a new name for the Group object."""
        self.name = name
        return self

    def add(self, obj):
        """Add an object to the group."""
        self.__iadd__(obj)
        return self

    def remove(self, obj):
        """Remove an object from the group."""
        self.__isub__(obj)
        return self

    def unpack(self) -> list:
        """Return the list of VTK props that are parts of the group."""
        parts = self.actor.GetParts()
        parts.InitTraversal()
        return [parts.GetItemAsObject(i) for i in range(parts.GetNumberOfItems())]

    def _unpack(self):
        """Alias of `unpack()`, kept for backward compatibility."""
        return self.unpack()

    def clear(self) -> "Group":
        """Remove all parts"""
        for part in self.unpack():
            self.actor.RemovePart(part)
        self.objects = []
        return self

    def on(self) -> "Group":
        """Switch on visibility"""
        self.actor.VisibilityOn()
        return self

    def off(self) -> "Group":
        """Switch off visibility"""
        self.actor.VisibilityOff()
        return self

    def pickable(self, value=True) -> "Group":
        """The pickability property of the Group."""
        self.actor.SetPickable(value)
        return self

    def use_bounds(self, value=True) -> "Group":
        """Set the use bounds property of the Group."""
        self.actor.SetUseBounds(value)
        return self

    def print(self) -> "Group":
        """Print info about the object."""
        print(self)
        return self


#################################################
class Assembly(CommonVisual, Actor3DHelper):
    """
    Group many objects and treat them as a single new object.
    """

    def __init__(self, *meshs):
        """
        Group many objects and treat them as a single new object,
        keeping track of internal transformations.

        Objects can be passed as separate arguments (nested lists are then
        flattened), as a single list, as a single object, or as a single
        dictionary, in which case each object is renamed after its key.

        A file can be loaded by passing its name as a string.
        Format must be `.npy`.

        Examples:
            - [gyroscope1.py](https://github.com/marcomusy/vedo/tree/master/examples/simulations/gyroscope1.py)

            ![](https://vedo.embl.es/images/simulations/39766016-85c1c1d6-52e3-11e8-8575-d167b7ce5217.gif)
        """
        super().__init__()

        self.actor = vtki.vtkAssembly()
        self.actor.retrieve_object = weak_ref_to(self)

        self.name = "Assembly"
        self.filename = ""
        self.rendered_at = set()
        self.scalarbar = None
        self.info = {}
        self.time = 0

        self.transform = LinearTransform()

        # Init by filename
        if len(meshs) == 1 and isinstance(meshs[0], str):
            filename = vedo.file_io.download(meshs[0], verbose=False)
            # NOTE(security): allow_pickle=True can execute code embedded in
            # the file: only load .npy files from trusted sources.
            data = np.load(filename, allow_pickle=True)
            try:
                # old format: an array of serialized objects
                meshs = [vedo.file_io.from_numpy(dd) for dd in data]
            except TypeError:
                # new format: a 0-d array wrapping a dictionary
                data = data.item()
                meshs = [vedo.file_io.from_numpy(ad) for ad in data["objects"][0]["parts"]]
                self.transform = LinearTransform(data["objects"][0]["transform"])
                self.actor.SetPosition(self.transform.T.GetPosition())
                self.actor.SetOrientation(self.transform.T.GetOrientation())
                self.actor.SetScale(self.transform.T.GetScale())

        # Name and load from dictionary
        if len(meshs) == 1 and isinstance(meshs[0], dict):
            named = meshs[0]
            for name in named:
                named[name].name = name
            meshs = list(named.values())
        elif len(meshs) == 1:
            # a single argument is either a container of objects or a lone
            # object: wrap the latter so that it can be iterated below
            single = meshs[0]
            meshs = single if vedo.utils.is_sequence(single) else [single]
        else:
            meshs = vedo.utils.flatten(meshs)

        self.objects = [m for m in meshs if m]
        self.actors  = [m.actor for m in self.objects]

        prop3d = vtki.get_class("Prop3D")
        scalarbars = []
        for a in self.actors:
            if isinstance(a, prop3d):
                self.actor.AddPart(a)
            # NOTE(review): __add__ reads `scalarbar` from the object,
            # here it is read from its actor - confirm which is intended.
            if hasattr(a, "scalarbar") and a.scalarbar is not None:
                scalarbars.append(a.scalarbar)

        if len(scalarbars) > 1:
            self.scalarbar = Group(scalarbars)
        elif len(scalarbars) == 1:
            self.scalarbar = scalarbars[0]

        self.pipeline = vedo.utils.OperationNode(
            "Assembly",
            parents=self.objects,
            comment=f"#meshes {len(self.objects)}",
            c="#f08080",
        )
        ##########################################

    def __str__(self):
        """Print info about Assembly object."""
        module = self.__class__.__module__
        cname = self.__class__.__name__
        out = vedo.printc(
            f"{module}.{cname} at ({hex(id(self))})".ljust(75),
            bold=True, invert=True, return_string=True,
        )
        out += "\x1b[0m"

        if self.name:
            out += "name".ljust(14) + ": " + self.name
            if "legend" in self.info.keys() and self.info["legend"]:
                out+= f", legend='{self.info['legend']}'"
            out += "\n"

        objs = self.unpack()
        out += "n. of objects".ljust(14) + ": " + str(len(objs)) + " "
        names = np.unique([a.name for a in objs if a.name])
        if len(names)>0:
            out += str(names).replace("'","")[:56]
        out += "\n"

        pos = self.actor.GetPosition()
        out += "position".ljust(14) + ": " + str(pos) + "\n"

        bnds = self.actor.GetBounds()
        bx1, bx2 = vedo.utils.precision(bnds[0], 3), vedo.utils.precision(bnds[1], 3)
        by1, by2 = vedo.utils.precision(bnds[2], 3), vedo.utils.precision(bnds[3], 3)
        bz1, bz2 = vedo.utils.precision(bnds[4], 3), vedo.utils.precision(bnds[5], 3)
        out += "bounds".ljust(14) + ":"
        out += " x=(" + bx1 + ", " + bx2 + "),"
        out += " y=(" + by1 + ", " + by2 + "),"
        out += " z=(" + bz1 + ", " + bz2 + ")\n"

        # subclasses (histograms) expose extra statistics
        if "Histogram1D" in cname:
            if self.title  != '': out += "title".ljust(14) + ": " + f'{self.title}\n'
            if self.xtitle and self.xtitle != ' ': out += "xtitle".ljust(14) + ": " + f'{self.xtitle}\n'
            if self.ytitle and self.ytitle != ' ': out += "ytitle".ljust(14) + ": " + f'{self.ytitle}\n'
            out += "entries".ljust(14) + ": " + f"{self.entries}\n"
            out += "mean, mode".ljust(14) + ": " + f"{self.mean:.6f}, {self.mode:.6f}\n"
            out += "std".ljust(14) + ": " + f"{self.std:.6f}"
        elif "Histogram2D" in cname:
            if self.title  != '': out += "title".ljust(14) + ": " + f'{self.title}\n'
            if self.xtitle and self.xtitle != ' ': out += "xtitle".ljust(14) + ": " + f'{self.xtitle}\n'
            if self.ytitle and self.ytitle != ' ': out += "ytitle".ljust(14) + ": " + f'{self.ytitle}\n'
            out += "entries".ljust(14) + ": " + f"{self.entries}\n"
            out += "mean".ljust(14) + ": " + f"{vedo.utils.precision(self.mean, 6)}\n"
            out += "std".ljust(14) + ": " + f"{vedo.utils.precision(self.std, 6)}"

        return out.rstrip() + "\x1b[0m"

    def _repr_html_(self):
        """
        HTML representation of the Assembly object for Jupyter Notebooks.

        Returns:
            HTML text with the image and some properties.
        """
        import io
        import base64
        from PIL import Image

        library_name = "vedo.assembly.Assembly"
        help_url = "https://vedo.embl.es/docs/vedo/assembly.html"

        arr = self.thumbnail(zoom=1.1, elevation=-60)

        im = Image.fromarray(arr)
        buffered = io.BytesIO()
        im.save(buffered, format="PNG", quality=100)
        encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
        url = "data:image/png;base64," + encoded
        image = f"<img src='{url}'></img>"

        # statistics
        bnds = self.bounds()
        bounds = "<br/>".join(
            vedo.utils.precision(min_x, 4) + " ... " + vedo.utils.precision(max_x, 4)
            for min_x, max_x in zip(bnds[::2], bnds[1::2])
        )

        help_text = ""
        if self.name:
            help_text += f"<b> {self.name}: &nbsp;&nbsp;</b>"
        help_text += '<b><a href="' + help_url + '" target="_blank">' + library_name + "</a></b>"
        if self.filename:
            dots = ""
            if len(self.filename) > 30:
                dots = "..."
            help_text += f"<br/><code><i>({dots}{self.filename[-30:]})</i></code>"

        allt = [
            "<table>",
            "<tr>",
            "<td>",
            image,
            "</td>",
            "<td style='text-align: center; vertical-align: center;'><br/>",
            help_text,
            "<table>",
            "<tr><td><b> nr. of objects </b></td><td>"
            + str(self.actor.GetNumberOfPaths())
            + "</td></tr>",
            "<tr><td><b> position </b></td><td>" + str(self.actor.GetPosition()) + "</td></tr>",
            "<tr><td><b> diagonal size </b></td><td>"
            + vedo.utils.precision(self.diagonal_size(), 5)
            + "</td></tr>",
            "<tr><td><b> bounds </b> <br/> (x/y/z) </td><td>" + str(bounds) + "</td></tr>",
            "</table>",
            "</table>",
        ]
        return "\n".join(allt)

    def __add__(self, obj):
        """
        Add an object to the assembly (in place) and return the assembly.

        Objects whose `.actor` is not a `vtkProp3D` are ignored.
        """
        if isinstance(getattr(obj, "actor", None), vtki.get_class("Prop3D")):

            self.objects.append(obj)
            self.actors.append(obj.actor)
            self.actor.AddPart(obj.actor)

            def unpack_group(scalarbar):
                # a Group of scalarbars contributes its individual parts
                if isinstance(scalarbar, Group):
                    return scalarbar.unpack()
                return scalarbar

            new_sb = getattr(obj, "scalarbar", None)
            if new_sb is not None:
                if self.scalarbar is None:
                    self.scalarbar = new_sb
                elif isinstance(self.scalarbar, Group):
                    self.scalarbar += unpack_group(new_sb)
                else:
                    self.scalarbar = Group([unpack_group(self.scalarbar), unpack_group(new_sb)])

            # record the operation also when the first scalarbar was adopted
            self.pipeline = vedo.utils.OperationNode("add mesh", parents=[self, obj], c="#f08080")
        return self

    def __isub__(self, obj):
        """
        Remove an object (or a list of objects) from the assembly.

        Items can be vedo objects or the VTK props they wrap.
        Items that are not part of the assembly are ignored.
        """
        items = obj if vedo.utils.is_sequence(obj) else [obj]
        for a in items:
            if not a:
                continue
            try:
                self.actor.RemovePart(a)
                prop = a
            except TypeError:
                prop = a.actor
                self.actor.RemovePart(prop)
            # keep the parallel `objects` / `actors` lists in sync
            for k, (o, act) in enumerate(zip(self.objects, self.actors)):
                if o is a or act is prop:
                    del self.objects[k]
                    del self.actors[k]
                    break
        return self

    def rename(self, name: str) -> "Assembly":
        """Set a new name for the Assembly object."""
        self.name = name
        return self

    def add(self, obj):
        """
        Add an object to the assembly.
        """
        self.__add__(obj)
        return self

    def remove(self, obj):
        """
        Remove an object from the assembly.
        """
        self.__isub__(obj)
        return self

    def __contains__(self, obj):
        """Allows to use `in` to check if an object is in the `Assembly`."""
        return obj in self.objects

    def __getitem__(self, i):
        """Return the i-th object, or the first object named `i` (None if not found)."""
        if isinstance(i, (int, np.integer)):
            return self.objects[i]
        if isinstance(i, str):
            for m in self.objects:
                if i == m.name:
                    return m
        return None

    def __len__(self):
        """Return nr. of objects in the assembly."""
        return len(self.objects)

    def write(self, filename="assembly.npy") -> Self:
        """
        Write the object to file in `numpy` format (npy).
        """
        vedo.file_io.write(self, filename)
        return self

    # TODO ####
    # def propagate_transform(self):
    #     """Propagate the transformation to all parts."""
    #     # navigate the assembly and apply the transform to all parts
    #     # and reset position, orientation and scale of the assembly
    #     for i in range(self.actor.GetNumberOfPaths()):
    #         path = self.actor.GetPath(i)
    #         obj = path.GetLastNode().GetViewProp()
    #         obj.SetUserTransform(self.transform.T)
    #         obj.SetPosition(0, 0, 0)
    #         obj.SetOrientation(0, 0, 0)
    #         obj.SetScale(1, 1, 1)
    #     raise NotImplementedError()

    def unpack(self, i=None) -> Union[List["vedo.Mesh"], "vedo.Mesh"]:
        """Unpack the list of objects from a `Assembly`.

        If `i` is given, get `i-th` object from a `Assembly`.
        Input can be a string, in this case returns the first object
        whose name equals the given string (an empty list if none matches).

        Examples:
            - [custom_axes4.py](https://github.com/marcomusy/vedo/tree/master/examples/pyplot/custom_axes4.py)
        """
        if i is None:
            return self.objects
        if isinstance(i, (int, np.integer)):
            return self.objects[i]
        if isinstance(i, str):
            for m in self.objects:
                if i == m.name:
                    return m
        return []

    def recursive_unpack(self) -> List["vedo.Mesh"]:
        """
        Flatten out an Assembly.

        Sub-assemblies among the direct children are expanded one level:
        their parts replace them in the output, as shifted clones when the
        sub-assembly has a non-zero position.
        """
        flat = []
        for elem in self.unpack():
            if not isinstance(elem, Assembly):
                flat.append(elem)
                continue
            apos = elem.actor.GetPosition()
            # any non-zero coordinate is an offset (a plain sum would
            # miss positions such as (1, -1, 0))
            shifted = bool(np.any(apos))
            for x in elem.unpack():
                flat.append(x.clone().shift(apos) if shifted else x)
        return flat

    def pickable(self, value=True) -> "Assembly":
        """Set the pickability property of an assembly and of all its elements."""
        self.actor.SetPickable(value)
        # act on the real children (sub-assemblies recurse themselves);
        # recursive_unpack() may return clones, which must not be used here
        for elem in self.unpack():
            elem.pickable(value)
        return self

    def clone(self) -> "Assembly":
        """Make a clone copy of the object. Same as `copy()`."""
        cloned = Assembly([a.clone() for a in self.objects])
        cloned.name = self.name
        cloned.info = dict(self.info)
        return cloned

    def clone2d(self, pos="bottom-left", size=1, rotation=0, ontop=False, justify="bottom-left") -> Group:
        """
        Convert the `Assembly` into a `Group` of 2D objects.

        Arguments:
            pos : (list, str)
                Position in 2D, as a string or list (x,y).
                The center of the renderer is [0,0], bottom-left is [-1,-1]
                and top-right is [1,1].
                Any combination of "center", "top", "bottom", "left" and "right" will work.
            size : (float)
                global scaling factor for the 2D object.
                The scaling is normalized to the x-range of the original object.
            rotation : (float)
                rotation angle in degrees (the original parts are not modified).
            ontop : (bool)
                if `True` the now 2D object is rendered on top of the 3D scene.
            justify : (str)
                justification for the 2D object. Only used when `pos` is given
                as coordinates: string positions set their own anchor.

        Returns:
            `Group` object.
        """
        padding = 0.05
        x0, x1 = self.xbounds()
        y0, y1 = self.ybounds()
        pp = self.pos()
        x0 -= pp[0]
        x1 -= pp[0]
        y0 -= pp[1]
        y1 -= pp[1]

        # choose offset based on justification
        offset = [x0, y0]
        if "cent" in justify:
            offset = [(x0 + x1) / 2, (y0 + y1) / 2]
            if "right" in justify:
                offset[0] = x1
            if "left" in justify:
                offset[0] = x0
            if "top" in justify:
                offset[1] = y1
            if "bottom" in justify:
                offset[1] = y0
        elif "top" in justify:
            if "right" in justify:
                offset = [x1, y1]
            elif "left" in justify:
                offset = [x0, y1]
            else:
                raise ValueError(f"incomplete justify='{justify}'")
        elif "bottom" in justify:
            if "right" in justify:
                offset = [x1, y0]
            elif "left" in justify:
                offset = [x0, y0]
            else:
                raise ValueError(f"incomplete justify='{justify}'")

        # choose position
        if "cent" in pos:
            offset = [(x0 + x1) / 2, (y0 + y1) / 2]
            position = [0., 0.]
            if "right" in pos:
                offset[0] = x1
                position = [1 - padding, 0]
            if "left" in pos:
                offset[0] = x0
                position = [-1 + padding, 0]
            if "top" in pos:
                offset[1] = y1
                position = [0, 1 - padding]
            if "bottom" in pos:
                offset[1] = y0
                position = [0, -1 + padding]
        elif "top" in pos:
            if "right" in pos:
                offset = [x1, y1]
                position = [1 - padding, 1 - padding]
            elif "left" in pos:
                offset = [x0, y1]
                position = [-1 + padding, 1 - padding]
            else:
                raise ValueError(f"incomplete position pos='{pos}'")
        elif "bottom" in pos:
            if "right" in pos:
                offset = [x1, y0]
                position = [1 - padding, -1 + padding]
            elif "left" in pos:
                offset = [x0, y0]
                position = [-1 + padding, -1 + padding]
            else:
                raise ValueError(f"incomplete position pos='{pos}'")
        else:
            position = pos

        group = Group()

        bnds = self.bounds()
        cm = [(bnds[0] + bnds[1]) / 2, (bnds[2] + bnds[3]) / 2, 0]

        for a in self.recursive_unpack():
            if not isinstance(a, vedo.Points) or a.npoints == 0:
                continue

            s = size * 500 / (x1 - x0)
            if a.properties.GetRepresentation() == 1:
                # wireframe is not rendered correctly in 2d
                src = a.boundaries().lw(1).c(a.color(), a.alpha())
                if rotation:
                    src.rotate_z(rotation, around=cm)
            else:
                src = a
                if rotation:
                    # rotate a copy: the parts of this assembly must not change
                    src = a.clone()
                    src.rotate_z(rotation, around=cm)
            a2d = src.clone2d(size=s, offset=offset)
            a2d.pos(position).ontop(ontop)
            group += a2d

        try: # copy info from Histogram1D
            group.entries = self.entries
            group.frequencies = self.frequencies
            group.errors = self.errors
            group.edges = self.edges
            group.centers = self.centers
            group.mean = self.mean
            group.mode = self.mode
            group.std = self.std
        except AttributeError:
            pass

        group.name = self.name
        return group

    def copy(self) -> "Assembly":
        """Return a copy of the object. Alias of `clone()`."""
        return self.clone()
